package com.example.socket.filter;

import com.example.socket.core.Session;
import com.example.socket.filter.session.NettySessionManager;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.EventExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.SocketAddress;
import java.util.Date;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Filter that closes idle connections.
 *
 * <p>When a channel becomes active, a check task is scheduled on the channel's executor. Every
 * inbound read refreshes the last-access time stored in {@link Session#ACCESS_TIME}.
 *
 * <p>When the task runs, it does one of the following:
 * <ul>
 *   <li>If no read has happened within the configured timeout, it closes the channel.</li>
 *   <li>Otherwise it reschedules itself for the remaining time.</li>
 *   <li>If the channel's {@link Session#INVALID_TIMES} counter exceeds the configured maximum, it
 *       also closes the channel.</li>
 * </ul>
 *
 * <p>Channels whose {@link Session#MANAGEMENT_KEY} attribute is {@code true} are never scheduled
 * for cleanup. A non-positive timeout (the default of the no-arg constructor) disables scheduling.
 *
 * <p>NOTE(review): the handler is {@link Sharable}, but {@code timeoutMillis} and
 * {@code maxInvalidTimes} are plain mutable fields. They are presumably configured once, before
 * the handler is added to any pipeline. Confirm with callers.
 */
@Sharable
public class CleanIdleFilter extends ChannelDuplexHandler implements FilterHandler {

    /** Channel attribute holding the pending idle-check task, so it can be cancelled on close. */
    private static final AttributeKey<Future<?>> TIMEOUT_FUTURE = AttributeKey.valueOf("timeFuture");
    /** Channel attribute marking that the idle check has already been set up for the channel. */
    private static final AttributeKey<Boolean> STATE = AttributeKey.valueOf("state");

    private final Logger logger = LoggerFactory.getLogger(getClass());

    /** Idle timeout in milliseconds; a value {@code <= 0} disables the idle check. */
    private long timeoutMillis;
    /** Number of invalid operations tolerated; the channel is closed once the count exceeds it. */
    private int maxInvalidTimes = 10;

    /** Creates a filter with no timeout; call {@link #setTimeoutSeconds(int)} to enable it. */
    public CleanIdleFilter() {
    }

    /**
     * Creates a filter with the given idle timeout.
     *
     * @param timeoutSeconds idle timeout in seconds
     */
    public CleanIdleFilter(int timeoutSeconds) {
        // TimeUnit converts in long arithmetic; the former `seconds * 1000` overflowed int.
        this.timeoutMillis = TimeUnit.SECONDS.toMillis(timeoutSeconds);
    }

    /**
     * Sets the connection idle timeout.
     *
     * @param timeoutSeconds idle timeout in seconds
     */
    public void setTimeoutSeconds(int timeoutSeconds) {
        // TimeUnit converts in long arithmetic; the former `seconds * 1000` overflowed int.
        this.timeoutMillis = TimeUnit.SECONDS.toMillis(timeoutSeconds);
    }

    /**
     * Records the initial access time and schedules the idle check for the new channel.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();
        updateTime(channel);
        initialize(ctx);
        super.channelActive(ctx);
    }

    /**
     * Refreshes the access time on every inbound message. The already-scheduled check picks up
     * the new time, so no new task is submitted here.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        updateTime(ctx.channel());
        super.channelRead(ctx, msg);
    }

    /**
     * Stores the current wall-clock time (ms) as the channel's last access time.
     *
     * @param channel the channel that was just accessed
     */
    protected void updateTime(Channel channel) {
        Attribute<Long> lastAccess = channel.attr(Session.ACCESS_TIME);
        lastAccess.set(System.currentTimeMillis());
    }

    /**
     * Cancels the pending idle check and clears this filter's channel attributes.
     *
     * <p>A {@link RejectedExecutionException} is logged and swallowed. Note that in that case
     * {@code super.channelInactive} is not reached.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        try {
            destroy(ctx);
            super.channelInactive(ctx);
        } catch (RejectedExecutionException e) {
            logger.warn("java.util.concurrent.RejectedExecutionException: event executor terminated");
        }
    }

    /** Cancels any scheduled idle check and resets the per-channel attributes used here. */
    private void destroy(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        Attribute<Future<?>> pending = channel.attr(TIMEOUT_FUTURE);
        Future<?> future = pending.get();
        pending.set(null);
        // isDone() is also true for cancelled futures, so a separate isCancelled() check is redundant.
        if (future != null && !future.isDone()) {
            future.cancel(true);
        }
        channel.attr(Session.ACCESS_TIME).set(null);
        channel.attr(STATE).set(null);
    }

    @Override
    public int getOrder() {
        return 50;
    }

    @Override
    public String getName() {
        return "idleFilter";
    }

    /**
     * Periodic check that closes the channel when it is idle or has too many invalid operations,
     * and otherwise reschedules itself for the remaining idle time.
     */
    private final class ReadTimeoutTask implements Runnable {
        private final ChannelHandlerContext ctx;

        ReadTimeoutTask(ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }

        @Override
        public void run() {
            Channel channel = ctx.channel();
            // Prefer the address recorded on the session, when a session is registered.
            SocketAddress address = channel.remoteAddress();
            Session session = NettySessionManager.lookup(channel);
            if (session != null) {
                address = session.getRemoteAddress();
            }
            if (!channel.isOpen()) {
                logger.debug("连接[{}]已关闭", address);
                channel.close();
                return;
            }

            // NOTE(review): this class only reads the counter; presumably other handlers increment it.
            Attribute<AtomicInteger> invalidAttribute = channel.attr(Session.INVALID_TIMES);
            AtomicInteger invalidTimes = invalidAttribute.get();
            if (invalidTimes == null) {
                invalidTimes = new AtomicInteger();
                invalidAttribute.set(invalidTimes);
            }
            int invalidCount = invalidTimes.get();
            if (invalidCount > maxInvalidTimes) {
                // Pass the count too: the message has two placeholders.
                logger.debug("连接[{}]错误次数[{}]过多被强制关闭, ", address, invalidCount);
                try {
                    channel.close();
                    return;
                } catch (Throwable t) {
                    // If close fails, report it and fall through to the regular idle check.
                    ctx.fireExceptionCaught(t);
                }
            }

            // A missing access time is treated as epoch 0, i.e. immediately timed out.
            Long lastReadTime = channel.attr(Session.ACCESS_TIME).get();
            if (lastReadTime == null) {
                lastReadTime = 0L;
            }

            long nextDelay = timeoutMillis - (System.currentTimeMillis() - lastReadTime);
            if (nextDelay <= 0) {
                logger.debug("连接[{}]最后访问[{}]超时关闭...", address, new Date(lastReadTime));
                try {
                    channel.close();
                } catch (Throwable t) {
                    ctx.fireExceptionCaught(t);
                }
            } else {
                logger.debug("连接[{}]超时时间更新...", address);
                Future<?> future = ctx.executor().schedule(this, nextDelay, TimeUnit.MILLISECONDS);
                channel.attr(TIMEOUT_FUTURE).set(future);
            }
        }
    }

    /**
     * Schedules the first idle check for the channel, at most once per channel.
     *
     * <p>The check is skipped for management connections and when the timeout is not positive.
     */
    private void initialize(ChannelHandlerContext ctx) {
        Channel channel = ctx.channel();
        Attribute<Boolean> initialized = channel.attr(STATE);
        if (Boolean.TRUE.equals(initialized.get())) {
            return;
        }
        initialized.set(true);

        Attribute<Boolean> managementAttr = channel.attr(Session.MANAGEMENT_KEY);
        if (Boolean.TRUE.equals(managementAttr.get())) {
            logger.debug("来自[{}]的管理连接,不做session定时清理", channel.remoteAddress());
            return;
        }
        if (timeoutMillis > 0) {
            EventExecutor loop = ctx.executor();
            Future<?> future = loop.schedule(new ReadTimeoutTask(ctx), timeoutMillis, TimeUnit.MILLISECONDS);
            channel.attr(TIMEOUT_FUTURE).set(future);
        }
    }

    /**
     * Sets how many invalid operations a channel may accumulate. The idle check closes the
     * channel once its counter is strictly greater than this value.
     *
     * @param maxInvalidTimes the tolerated number of invalid operations
     */
    public void setMaxInvalidTimes(int maxInvalidTimes) {
        this.maxInvalidTimes = maxInvalidTimes;
    }
}
